/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.streaming.tests;

import org.apache.flink.api.common.functions.RichFlatMapFunction;
import org.apache.flink.api.common.functions.RuntimeContext;
import org.apache.flink.api.common.restartstrategy.RestartStrategies;
import org.apache.flink.api.common.state.CheckpointListener;
import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.state.ValueState;
import org.apache.flink.api.common.state.ValueStateDescriptor;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.api.java.utils.ParameterTool;
import org.apache.flink.contrib.streaming.state.EmbeddedRocksDBStateBackend;
import org.apache.flink.runtime.state.FunctionInitializationContext;
import org.apache.flink.runtime.state.FunctionSnapshotContext;
import org.apache.flink.runtime.state.hashmap.HashMapStateBackend;
import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
import org.apache.flink.streaming.api.environment.CheckpointConfig;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.functions.sink.PrintSinkFunction;
import org.apache.flink.streaming.api.functions.source.RichParallelSourceFunction;
import org.apache.flink.streaming.api.operators.StreamingRuntimeContext;
import org.apache.flink.util.Collector;
import org.apache.flink.util.Preconditions;

import org.apache.commons.lang3.RandomStringUtils;

import java.io.IOException;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Set;

/**
 * Automatic end-to-end test for local recovery (including sticky allocation).
 *
 * <p>List of possible input parameters for this job:
 *
 * <ul>
 *   <li>checkpointDir: the checkpoint directory, required.
 *   <li>parallelism: the parallelism of the job, default 1.
 *   <li>maxParallelism: the maximum parallelism of the job, default 1.
 *   <li>checkpointInterval: the checkpointing interval in milliseconds, default 1000.
 *   <li>restartDelay: the delay of the fixed delay restart strategy, default 0.
 *   <li>externalizedCheckpoints: flag to activate externalized checkpoints, default <code>false
 *       </code>.
 *   <li>stateBackend: choice for state backend between <code>hashmap</code> and <code>rocks</code>,
 *       default <code>hashmap</code>.
 *   <li>killJvmOnFail: flag that determines whether or not an artificial failure induced by the
 *       test kills the JVM or not.
 *   <li>incrementalCheckpoints: flag for incremental checkpoint with rocks state backend, default
 *       <code>false</code>.
 *   <li>delay: sleep delay to throttle down the production of the source, default 0.
 *   <li>maxAttempts: the maximum number of run attempts, before the job finishes with success,
 *       default 3.
 *   <li>valueSize: size of the artificial value for each key in bytes, default 10.
 * </ul>
 */
public class StickyAllocationAndLocalRecoveryTestJob {

    public static void main(String[] args) throws Exception {

        final ParameterTool params = ParameterTool.fromArgs(args);

        final StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();

        // Parallelism, checkpoint interval, and an effectively unlimited fixed-delay restart
        // strategy; the job itself decides when to stop failing (via maxAttempts in the source).
        final int parallelism = params.getInt("parallelism", 1);
        env.setParallelism(parallelism);
        env.setMaxParallelism(params.getInt("maxParallelism", parallelism));
        env.enableCheckpointing(params.getInt("checkpointInterval", 1000));
        env.setRestartStrategy(
                RestartStrategies.fixedDelayRestart(
                        Integer.MAX_VALUE, params.getInt("restartDelay", 0)));

        if (params.getBoolean("externalizedCheckpoints", false)) {
            env.getCheckpointConfig()
                    .setExternalizedCheckpointCleanup(
                            CheckpointConfig.ExternalizedCheckpointCleanup.RETAIN_ON_CANCELLATION);
        }

        // checkpointDir is the only required parameter.
        env.getCheckpointConfig().setCheckpointStorage(params.getRequired("checkpointDir"));

        // Select the state backend; "rocks" optionally uses incremental checkpoints.
        final String stateBackend = params.get("stateBackend", "hashmap");
        switch (stateBackend) {
            case "hashmap":
                env.setStateBackend(new HashMapStateBackend());
                break;
            case "rocks":
                env.setStateBackend(
                        new EmbeddedRocksDBStateBackend(
                                params.getBoolean("incrementalCheckpoints", false)));
                break;
            default:
                throw new IllegalArgumentException("Unknown backend: " + stateBackend);
        }

        // make parameters available in the web interface
        env.getConfig().setGlobalJobParameters(params);

        // Knobs for the test operators: JVM-killing failures, source throttling, the number of
        // induced failures before success, and the size of each artificial keyed-state value.
        final boolean killJvmOnFail = params.getBoolean("killJvmOnFail", false);
        final long delay = params.getLong("delay", 0L);
        final int maxAttempts = params.getInt("maxAttempts", 3);
        final int valueSize = params.getInt("valueSize", 10);

        env.addSource(new RandomLongSource(maxAttempts, delay))
                .keyBy((KeySelector<Long, Long>) aLong -> aLong)
                .flatMap(new StateCreatingFlatMap(valueSize, killJvmOnFail))
                .addSink(new PrintSinkFunction<>());

        env.execute("Sticky Allocation And Local Recovery Test");
    }

    /**
     * Source that emits an interleaved sequence of longs: subtask i emits keys i, i + p, i + 2p,
     * ... for parallelism p, so each key is produced by exactly one subtask. The current key is
     * kept in operator state so that after recovery the sequence continues from the last
     * checkpointed key. Once the attempt number exceeds {@code maxAttempts}, the source emits a
     * single final marker event and finishes, letting the job end with success.
     */
    private static final class RandomLongSource extends RichParallelSourceFunction<Long>
            implements CheckpointedFunction {

        private static final long serialVersionUID = 1L;

        /** Generator delay between two events in milliseconds; 0 disables throttling. */
        final long delay;

        /** Maximum restarts before shutting down this source. */
        final int maxAttempts;

        /** Operator state that holds the current key for recovery (at most one element). */
        transient ListState<Long> sourceCurrentKeyState;

        /** Generator's current key; advanced by the parallelism after each emission. */
        long currentKey;

        /** Generator runs while this is true; cleared by {@link #cancel()}. */
        volatile boolean running;

        RandomLongSource(int maxAttempts, long delay) {
            this.delay = delay;
            this.maxAttempts = maxAttempts;
            this.running = true;
        }

        @Override
        public void run(SourceContext<Long> sourceContext) throws Exception {

            int numberOfParallelSubtasks = getRuntimeContext().getNumberOfParallelSubtasks();
            int subtaskIdx = getRuntimeContext().getIndexOfThisSubtask();

            // the source emits one final event and shuts down once we have reached max attempts.
            // Long.MAX_VALUE - subtaskIdx keeps the final marker unique per subtask.
            if (getRuntimeContext().getAttemptNumber() > maxAttempts) {
                synchronized (sourceContext.getCheckpointLock()) {
                    sourceContext.collect(Long.MAX_VALUE - subtaskIdx);
                }
                return;
            }

            while (running) {

                // Emit and advance the key under the checkpoint lock so that a snapshot of
                // currentKey is always consistent with the events emitted so far.
                synchronized (sourceContext.getCheckpointLock()) {
                    sourceContext.collect(currentKey);
                    currentKey += numberOfParallelSubtasks;
                }

                if (delay > 0) {
                    Thread.sleep(delay);
                }
            }
        }

        @Override
        public void cancel() {
            running = false;
        }

        @Override
        public void snapshotState(FunctionSnapshotContext context) throws Exception {
            // Replace any previous snapshot with the single current key.
            sourceCurrentKeyState.clear();
            sourceCurrentKeyState.add(currentKey);
        }

        @Override
        public void initializeState(FunctionInitializationContext context) throws Exception {

            ListStateDescriptor<Long> currentKeyDescriptor =
                    new ListStateDescriptor<>("currentKey", Long.class);
            sourceCurrentKeyState =
                    context.getOperatorStateStore().getListState(currentKeyDescriptor);

            // Fresh start: begin at the subtask index. On restore, continue from the checkpointed
            // key instead; the state is expected to hold at most one element per subtask.
            currentKey = getRuntimeContext().getIndexOfThisSubtask();
            Iterable<Long> iterable = sourceCurrentKeyState.get();
            if (iterable != null) {
                Iterator<Long> iterator = iterable.iterator();
                if (iterator.hasNext()) {
                    currentKey = iterator.next();
                    Preconditions.checkState(!iterator.hasNext());
                }
            }
        }
    }

    /**
     * Stateful map function. Failure creation and checks happen here: on restore it compares the
     * allocation it now runs in against the allocation recorded in operator state before the
     * failure. A mismatch that cannot be explained by a killed JVM is reported downstream as a
     * sticky-allocation failure message.
     */
    private static final class StateCreatingFlatMap extends RichFlatMapFunction<Long, String>
            implements CheckpointedFunction, CheckpointListener {

        private static final long serialVersionUID = 1L;

        /** User configured size of the generated artificial values in the keyed state. */
        final int valueSize;

        /** Holds the user configuration if the artificial test failure is killing the JVM. */
        final boolean killTaskOnFailure;

        /** This state is used to create artificial keyed state in the backend. */
        transient ValueState<String> valueState;

        /** This state is used to persist the schedulingAndFailureInfo to state (union state). */
        transient ListState<MapperSchedulingAndFailureInfo> schedulingAndFailureState;

        /** This contains the current scheduling and failure meta data. */
        transient MapperSchedulingAndFailureInfo currentSchedulingAndFailureInfo;

        /** Message to indicate that recovery detected a failure with sticky allocation. */
        transient volatile String allocationFailureMessage;

        /**
         * If this flag is true, the next invocation of the map function introduces a test failure.
         */
        transient volatile boolean failTask;

        StateCreatingFlatMap(int valueSize, boolean killTaskOnFailure) {
            this.valueSize = valueSize;
            this.failTask = false;
            this.killTaskOnFailure = killTaskOnFailure;
            this.allocationFailureMessage = null;
        }

        /**
         * Forwards a pending allocation-failure message (if any), injects the artificial failure
         * when this subtask was selected to fail, and otherwise stores an artificial random value
         * for the key.
         *
         * @param key input key; assumed never to repeat across the job (see the sanity check).
         * @param collector collects failure messages for the downstream sink.
         * @throws IOException if keyed state access fails.
         */
        @Override
        public void flatMap(Long key, Collector<String> collector) throws IOException {

            if (allocationFailureMessage != null) {
                // Report the failure downstream, so that we can get the message from the output.
                collector.collect(allocationFailureMessage);
                allocationFailureMessage = null;
            }

            if (failTask) {
                // we fail the task, either by killing the JVM hard, or by throwing a user code
                // exception.
                if (killTaskOnFailure) {
                    Runtime.getRuntime().halt(-1);
                } else {
                    throw new RuntimeException("Artificial user code exception.");
                }
            }

            // sanity check: each key must be seen at most once, so its state must still be empty.
            if (null != valueState.value()) {
                throw new IllegalStateException(
                        "This should never happen, keys are generated monotonously.");
            }

            // store artificial data to blow up the state
            valueState.update(RandomStringUtils.random(valueSize, true, true));
        }

        @Override
        public void snapshotState(FunctionSnapshotContext functionSnapshotContext) {}

        /**
         * Registers keyed and operator state and, on restore, validates sticky allocation: this
         * subtask must either run in the same allocation as before the failure, or its previous
         * JVM must have been killed. Finally records this attempt's scheduling info so the next
         * restore can repeat the check.
         */
        @Override
        public void initializeState(FunctionInitializationContext functionInitializationContext)
                throws Exception {
            ValueStateDescriptor<String> stateDescriptor =
                    new ValueStateDescriptor<>("state", String.class);
            valueState =
                    functionInitializationContext.getKeyedStateStore().getState(stateDescriptor);

            // Union list state: on restore, each subtask receives the entries written by ALL
            // subtasks of the previous attempt, so the full picture is available everywhere.
            ListStateDescriptor<MapperSchedulingAndFailureInfo> mapperInfoStateDescriptor =
                    new ListStateDescriptor<>("mapperState", MapperSchedulingAndFailureInfo.class);
            schedulingAndFailureState =
                    functionInitializationContext
                            .getOperatorStateStore()
                            .getUnionListState(mapperInfoStateDescriptor);

            StreamingRuntimeContext runtimeContext = (StreamingRuntimeContext) getRuntimeContext();
            String allocationID = runtimeContext.getAllocationIDAsString();
            // Pattern of the name: "Flat Map -> Sink: Unnamed (4/4)#0". Remove "#0" part:
            String taskNameWithSubtasks = runtimeContext.getTaskNameWithSubtasks().split("#")[0];

            final int thisJvmPid = getJvmPid();
            final Set<Integer> killedJvmPids = new HashSet<>();

            // here we check if the sticky scheduling worked as expected
            if (functionInitializationContext.isRestored()) {
                Iterable<MapperSchedulingAndFailureInfo> iterable = schedulingAndFailureState.get();

                // Find the entry for this subtask (matched by task name + subtask index) and
                // collect the PIDs of all JVMs that were deliberately killed in the last attempt.
                MapperSchedulingAndFailureInfo infoForThisTask = null;
                List<MapperSchedulingAndFailureInfo> completeInfo = new ArrayList<>();
                if (iterable != null) {
                    for (MapperSchedulingAndFailureInfo testInfo : iterable) {

                        completeInfo.add(testInfo);

                        if (taskNameWithSubtasks.equals(testInfo.taskNameWithSubtask)) {
                            infoForThisTask = testInfo;
                        }

                        if (testInfo.killedJvm) {
                            killedJvmPids.add(testInfo.jvmPid);
                        }
                    }
                }

                Preconditions.checkNotNull(infoForThisTask, "Expected to find info here.");

                // A violation is not thrown here but buffered and emitted from flatMap, so the
                // message shows up in the job output where the test harness can find it.
                if (!isScheduledToCorrectAllocation(infoForThisTask, allocationID, killedJvmPids)) {
                    allocationFailureMessage =
                            String.format(
                                    "Sticky allocation test failed: Subtask %s in attempt %d was rescheduled from allocation %s "
                                            + "on JVM with PID %d to unexpected allocation %s on JVM with PID %d.\n"
                                            + "Complete information from before the crash: %s.",
                                    taskNameWithSubtasks,
                                    runtimeContext.getAttemptNumber(),
                                    infoForThisTask.allocationId,
                                    infoForThisTask.jvmPid,
                                    allocationID,
                                    thisJvmPid,
                                    completeInfo);
                }
            }

            // We determine which of the subtasks will produce the artificial failure
            boolean failingTask = shouldTaskFailForThisAttempt();

            // We take note of all the meta info that we require to check sticky scheduling in the
            // next re-attempt
            this.currentSchedulingAndFailureInfo =
                    new MapperSchedulingAndFailureInfo(
                            failingTask,
                            failingTask && killTaskOnFailure,
                            thisJvmPid,
                            taskNameWithSubtasks,
                            allocationID);

            schedulingAndFailureState.clear();
            schedulingAndFailureState.add(currentSchedulingAndFailureInfo);
        }

        @Override
        public void notifyCheckpointComplete(long checkpointId) {
            // we can only fail the task after at least one checkpoint is completed to record
            // progress.
            failTask = currentSchedulingAndFailureInfo.failingTask;
        }

        @Override
        public void notifyCheckpointAborted(long checkpointId) {}

        /**
         * Selects the failing subtask for this attempt by rotating through the subtask indices
         * (attempt number modulo parallelism), so each restart fails a different subtask.
         */
        private boolean shouldTaskFailForThisAttempt() {
            RuntimeContext runtimeContext = getRuntimeContext();
            int numSubtasks = runtimeContext.getNumberOfParallelSubtasks();
            int subtaskIdx = runtimeContext.getIndexOfThisSubtask();
            int attempt = runtimeContext.getAttemptNumber();
            return (attempt % numSubtasks) == subtaskIdx;
        }

        /**
         * A reschedule is considered correct if the task kept its previous allocation, or if its
         * previous JVM was killed (in which case the old allocation cannot exist anymore).
         */
        private boolean isScheduledToCorrectAllocation(
                MapperSchedulingAndFailureInfo infoForThisTask,
                String allocationID,
                Set<Integer> killedJvmPids) {

            return (infoForThisTask.allocationId.equals(allocationID)
                    || killedJvmPids.contains(infoForThisTask.jvmPid));
        }
    }

    /**
     * Returns the PID of this JVM.
     *
     * <p>The PID is parsed from {@link java.lang.management.RuntimeMXBean#getName()}, which by
     * long-standing convention returns a string of the form {@code "pid@hostname"}. This replaces
     * the previous reflective access to {@code sun.management.VMManagement}, which is an internal
     * JDK class and fails with {@code InaccessibleObjectException} on Java 9+ where such internals
     * are encapsulated by the module system.
     *
     * @return the process id of this JVM.
     * @throws Exception if the runtime MX bean name does not follow the {@code pid@hostname}
     *     convention and no PID can be parsed from it.
     */
    private static int getJvmPid() throws Exception {
        final String jvmName = java.lang.management.ManagementFactory.getRuntimeMXBean().getName();
        final int atIndex = jvmName.indexOf('@');
        if (atIndex < 0) {
            throw new IllegalStateException("Cannot parse JVM PID from name: " + jvmName);
        }
        return Integer.parseInt(jvmName.substring(0, atIndex));
    }

    /**
     * Serializable record of where a mapper subtask was scheduled and whether it was chosen to
     * fail, persisted in operator state so that after a restart the subtasks can validate sticky
     * allocation.
     */
    public static class MapperSchedulingAndFailureInfo implements Serializable {

        private static final long serialVersionUID = 1L;

        /** True iff this task inflicts a test failure. */
        final boolean failingTask;

        /** True iff this task kills its JVM. */
        final boolean killedJvm;

        /** PID of the task JVM. */
        final int jvmPid;

        /** Name and subtask index of the task, used to re-identify the entry after a restart. */
        final String taskNameWithSubtask;

        /** The current allocation id of this task. */
        final String allocationId;

        MapperSchedulingAndFailureInfo(
                boolean failingTask,
                boolean killedJvm,
                int jvmPid,
                String taskNameWithSubtask,
                String allocationId) {

            this.failingTask = failingTask;
            this.killedJvm = killedJvm;
            this.jvmPid = jvmPid;
            this.taskNameWithSubtask = taskNameWithSubtask;
            this.allocationId = allocationId;
        }

        @Override
        public String toString() {
            return String.format(
                    "MapperTestInfo{failingTask=%s, killedJvm=%s, jvmPid=%d,"
                            + " taskNameWithSubtask='%s', allocationId='%s'}",
                    failingTask, killedJvm, jvmPid, taskNameWithSubtask, allocationId);
        }
    }
}
